[ckpt] refactor: Consolidate fused expert mappings and fix MTP inference#2685
Conversation
Introduce FusedExpertMapping and FusedGatedExpertMapping in param_mapping.py to handle many-to-one / one-to-many expert weight conversions generically. This eliminates duplicated maybe_modify_converted_hf_weight overrides and hf_weights_cache from GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (-502 / +307 lines).

Also fixes two pre-existing bugs:
- GLM-4.5 MTP mappings used stale 'transformer_layer' instead of 'mtp_model_layer', causing missing-mapping warnings
- hf_to_megatron_generate_text.py set mtp_num_layers=None, which crashed MTP-enabled models; replaced with m.mtp_process=False

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor
/ok to test ff3705b
📝 Walkthrough

This PR introduces fused expert mapping classes and grouped export accumulation logic for optimized MoE weight conversion, replaces legacy per-expert mapping implementations across multiple model bridges (Qwen, GPT-OSS, GLM) with the new fused variants, and updates MTP inference handling by conditionally disabling the mtp_process attribute.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant ModelBridge as MegatronModelBridge
    participant Accum as Grouped<br/>Accumulator
    participant GrpBuf as GroupedBuffers<br/>(Tensor Cache)
    participant Output as Merged<br/>Tensor Dict

    Client->>ModelBridge: stream_weights_hf_to_megatron<br/>(with grouped_export mapping)
    ModelBridge->>ModelBridge: Detect is_grouped_export=True
    ModelBridge->>GrpBuf: Initialize grouped_buffers[group_key]
    loop For Each Expert in Group
        ModelBridge->>ModelBridge: Load HF weight slice
        ModelBridge->>Accum: _accumulate_grouped_export<br/>(expert_idx, weight)
        Accum->>GrpBuf: Store per-expert weight<br/>at global expert index
    end
    Accum->>Accum: All experts collected?
    alt Yes - Group Complete
        Accum->>Accum: Stack/merge expert<br/>tensors into single tensor
        Accum->>Accum: Optionally transpose<br/>to match shape
        Accum->>Output: Return merged dict
        Output->>Client: Yield merged tensor
    else No - Still Accumulating
        Accum->>Client: Yield None (continue)
    end
```
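The accumulate-then-merge step in the diagram can be sketched as follows. This is a minimal stand-in, not the bridge's real `_accumulate_grouped_export`: the function name's behavior is inferred from the diagram, and plain lists stand in for torch tensors (real code would `torch.stack` along a new leading expert dimension).

```python
def accumulate_grouped_export(buffers, group_key, expert_idx, num_experts, weight):
    """Collect per-expert weights for one fused group.

    Returns None while the group is still accumulating; once every expert
    has arrived, merges them in global expert order, frees the buffer, and
    returns the merged result.
    """
    group = buffers.setdefault(group_key, {})
    group[expert_idx] = weight
    if len(group) < num_experts:
        return None  # group incomplete: caller yields nothing yet
    # All experts collected: merge in expert-index order and free the cache.
    merged = [group[i] for i in range(num_experts)]
    del buffers[group_key]
    return merged
```

Each call either returns `None` (keep streaming) or the fully merged group, matching the `alt`/`else` branches in the diagram.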
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 3
⚠️ Outside diff range comments (1)
src/megatron/bridge/models/glm/glm45_bridge.py (1)
Lines 218-300: ⚠️ Potential issue | 🟠 Major

Add dual-prefix support for MTP layer mappings to handle both Megatron-Core naming conventions.

The MTP mappings currently hard-code only `mtp_model_layer` in the explicit QKV/MLP/expert mappings (lines 250, 256, 262, 267, 277, 284, 295, 300) and in the generated `AutoMapping` entries at line 218. Megatron-Core may expose the MTP submodule as `transformer_layer` instead, which will leave MTP weights unmapped for those checkpoints. Follow the pattern in `mimo_bridge.py` by iterating over both prefixes to ensure compatibility across different Megatron-Core versions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/glm/glm45_bridge.py` around lines 218 - 300, The MTP mappings only use the "mtp_model_layer" prefix causing missed mappings when Megatron exposes the submodule as "transformer_layer"; update the mapping construction to loop over both prefixes (e.g., prefixes = ["mtp_model_layer", "transformer_layer"]) and add mappings for each prefix so every place that currently constructs megatron_param with "mtp_model_layer" (including the AutoMapping entries and the specialized mappings: QKVMapping, GatedMLPMapping, GLMExpertGateUpProjMapping, GLMExpertDownProjMapping, and the existing AutoMapping for experts) is duplicated/created for the alternate "transformer_layer" prefix; follow the pattern used in mimo_bridge.py to generate entries for both prefixes and append them to mapping_list.
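The dual-prefix fix the review asks for amounts to generating every MTP mapping once per prefix. A minimal sketch, assuming a hypothetical param-name template (only the two prefix names come from the review; the template string is illustrative, not the bridge's real mapping):

```python
def build_mtp_mapping_params(layer_idx: int) -> list[str]:
    """Emit one Megatron param name per MTP submodule prefix.

    Real code would append full QKVMapping/GatedMLPMapping/AutoMapping
    objects to mapping_list for each prefix, as mimo_bridge.py does.
    """
    params = []
    for prefix in ("mtp_model_layer", "transformer_layer"):
        params.append(
            f"mtp.layers.{layer_idx}.{prefix}.self_attention.linear_qkv.weight"
        )
    return params
```

Whichever name Megatron-Core actually exposes, one of the generated mappings matches and the MTP weights are no longer skipped.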
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/conversion/hf_to_megatron_generate_text.py`:
- Around line 171-172: The current change only flips the model instance flag
m.mtp_process, but you must also disable MTP at the config level and clear
mixed-precision scaling to avoid NCCL hangs: when you see the block that checks
hasattr(m, "mtp_process") and sets m.mtp_process = False, also set
m.config.mtp_num_layers = None (or 0 if config expects an int) and set
m.grad_scale_func = None, using attribute existence checks before assignment to
avoid attribute errors; update the same function/section that handles
m.mtp_process so all three changes are applied together.
In `@src/megatron/bridge/models/glm/glm_moe_mappings.py`:
- Around line 21-23: Module currently only re-exports GLMExpertDownProjMapping
causing import-time failure where GLMExpertGateUpProjMapping is expected; add a
matching re-export for the gate mapping by importing the appropriate symbol from
megatron.bridge.models.conversion.param_mapping and aliasing it to
GLMExpertGateUpProjMapping (mirror the existing pattern used for
GLMExpertDownProjMapping), so downstream code that imports and instantiates
GLMExpertGateUpProjMapping will succeed.
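The re-export pattern being requested can be demonstrated with a stand-in module (the real import path is `megatron.bridge.models.conversion.param_mapping`; the stub below only exists so the snippet is self-contained, and the class bodies are placeholders):

```python
import sys
import types

# Stand-in for param_mapping, which defines the renamed fused classes.
pm = types.ModuleType("param_mapping_stub")

class FusedExpertMapping:  # placeholder body
    pass

class FusedGatedExpertMapping:  # placeholder body
    pass

pm.FusedExpertMapping = FusedExpertMapping
pm.FusedGatedExpertMapping = FusedGatedExpertMapping
sys.modules["param_mapping_stub"] = pm

# The glm_moe_mappings.py fix: re-export BOTH fused classes under the old
# GLM names so existing bridge imports (glm45_bridge.py, glm_45v_bridge.py)
# keep resolving.
from param_mapping_stub import FusedExpertMapping as GLMExpertDownProjMapping  # noqa: F401
from param_mapping_stub import FusedGatedExpertMapping as GLMExpertGateUpProjMapping  # noqa: F401
```

The aliases are plain name bindings, so downstream `isinstance` checks and instantiation behave identically to importing the fused classes directly.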
In `@src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py`:
- Around line 121-130: The quantized-path returning _dequantize_mxfp4(blocks,
scales) doesn't mirror the direct-tensor branch's transpose for 3D expert
weights, causing expert tensors to keep HF layout; update the branch handling
blocks_key/scales_key so that after calling _dequantize_mxfp4 you detect if
hf_param contains ".mlp.experts." and the returned tensor has ndim == 3, then
transpose the last two axes (i.e., swap -1 and -2) before returning; locate this
logic around the hf_param string branch that references hf_state_dict,
_dequantize_mxfp4, and the ".mlp.experts." selector to apply the fix.
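The layout fix for the quantized path can be sketched as below. Nested lists stand in for tensors so the snippet is self-contained (real code would call `tensor.transpose(-1, -2)` on the `_dequantize_mxfp4` output); the function names here are illustrative, not the bridge's.

```python
def _transpose_last_two(t):
    """Swap the last two axes of a 3D nested-list 'tensor'
    (stand-in for torch.Tensor.transpose(-1, -2))."""
    return [[list(col) for col in zip(*mat)] for mat in t]

def dequantized_expert_weight(hf_param, dequantized):
    """Mirror the direct-tensor branch: for 3D expert weights coming out of
    the MXFP4 dequantize path, transpose the last two axes so the layout
    matches what the direct-tensor branch produces."""
    is_expert = ".mlp.experts." in hf_param
    is_3d = (
        isinstance(dequantized, list)
        and dequantized and isinstance(dequantized[0], list)
        and dequantized[0] and isinstance(dequantized[0][0], list)
    )
    if is_expert and is_3d:
        return _transpose_last_two(dequantized)
    return dequantized
```

Non-expert params (and anything not 3D) pass through untouched, so only the expert tensors change layout.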
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 32a098f8-3a4b-4da8-b6f7-927e1570c4c4
📒 Files selected for processing (10)
- examples/conversion/hf_to_megatron_generate_text.py
- src/megatron/bridge/models/conversion/__init__.py
- src/megatron/bridge/models/conversion/model_bridge.py
- src/megatron/bridge/models/conversion/param_mapping.py
- src/megatron/bridge/models/glm/glm45_bridge.py
- src/megatron/bridge/models/glm/glm_moe_mappings.py
- src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
- src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
- src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py
- src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
💤 Files with no reviewable changes (1)
- src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
/ok to test ff3705b
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 8d66144
…mappings

The refactor in param_mapping.py renamed GLMExpertGateUpProjMapping to FusedGatedExpertMapping but only added the GLMExpertDownProjMapping alias in glm_moe_mappings.py. Add the missing alias so existing bridge imports (glm45_bridge.py, glm_45v_bridge.py) continue to work.

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 2dfbb99
Split the multi-name import block into two separate import statements, each with a per-line # noqa: F401 comment, to satisfy ruff's import block formatting requirements.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 3cc9686
…ext tests

- Set PROVIDER_CLASS = Qwen3NextModelProvider so super().provider_bridge() instantiates the correct provider (not GPTModelProvider, which lacks MLA/hybrid fields like q_lora_rank)
- Add a `value is not None` guard in hf_config_to_provider_kwargs to skip None-valued config fields
- Add a null_attr fixture loop in test mocks to suppress Mock() objects for MLA/alternative-expert CONFIG_MAPPING fields

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 29616bc
cuichenx left a comment
please verify roundtrip and inference outputs of these models
```python
for m in model:
    m.config.mtp_num_layers = None
    if hasattr(m, "mtp_process"):
        m.mtp_process = False
```
make use of _disable_mtp function from hf_to_megatron_generate_vlm.py?
yes need to use
```python
# Disable MTP for inference (MTP is only used during training)
def _disable_mtp(m):
    """Disable MTP on a model by clearing mtp_process on the language model."""
    m.config.mtp_num_layers = None
    inner = m.module if hasattr(m, "module") else m
    lang = getattr(inner, "language_model", inner)
    if hasattr(lang, "mtp_process"):
        lang.mtp_process = False
```
/ok to test a03ca50
…export

With etp=1 and ep=1, TEGroupedLinear uses explicit_expert_comm=False, so expert weights are stored in [out, in] (PyTorch) rather than [in, out] (TE) layout. The unconditional transpose_on_export=True in GLMExpertDownProjMapping then incorrectly flips the stacked [8, 1024, 512] to [8, 512, 1024], causing torch.allclose to raise a shape mismatch in the GLM-4.5 TP=2 round-trip test.

Fix: when hf_state_dict is available, only transpose if the stacked shape doesn't already match the original HF shape but the transposed shape does (same adaptive logic as the old maybe_modify_converted_hf_weight). Fall back to unconditional transpose when no HF reference is available.

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
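The adaptive-transpose rule described in this commit message can be sketched like this. Nested lists and a shape helper stand in for tensors and `.shape` (the helper names are illustrative, not the mapping's real methods):

```python
def _shape(t):
    """Shape tuple of a rectangular nested list (stand-in for tensor.shape)."""
    s = []
    while isinstance(t, list):
        s.append(len(t))
        t = t[0]
    return tuple(s)

def _transpose_last_two(t):
    """Swap the last two axes of a 3D nested-list 'tensor'."""
    return [[list(col) for col in zip(*mat)] for mat in t]

def adaptive_export_transpose(stacked, hf_ref=None):
    """Transpose only when the stacked layout disagrees with the HF reference.

    - shapes already match           -> keep as-is ([out, in] PyTorch layout)
    - transposed shape matches       -> transpose ([in, out] TE layout)
    - no HF reference available      -> fall back to unconditional transpose
    """
    transposed = _transpose_last_two(stacked)
    if hf_ref is None:
        return transposed
    if _shape(stacked) == _shape(hf_ref):
        return stacked
    if _shape(transposed) == _shape(hf_ref):
        return transposed
    return stacked
```

This reproduces the behavior described above: with etp=1/ep=1 the stacked tensor already matches the HF shape and is left alone, while the TE layout still gets flipped.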
/ok to test c336f9c
Summary

- Introduce `FusedExpertMapping` and `FusedGatedExpertMapping` in `param_mapping.py` to handle many-to-one / one-to-many expert weight conversions generically via the `is_grouped_export` / `group_key` protocol
- Eliminate duplicated `maybe_modify_converted_hf_weight` overrides and `hf_weights_cache` from GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (net -195 lines)
- Add `_accumulate_grouped_export` to `MegatronModelBridge` and `_hf_import_cache` for grouped import, centralizing the expert merge/split logic
- Replace stale `transformer_layer` with `mtp_model_layer` and propagate `mtp_num_layers` from the HF config
- `hf_to_megatron_generate_text.py`: replace `mtp_num_layers=None` (crashes MTP-enabled models) with `m.mtp_process=False`

Test plan
Made with Cursor